Loss function: Earth movel distance - The effort needed to make both distributions equal
- Critics(discriminator) values not restricted to be between 0 and 1
- Even for very different distributions, gradients are significant and high enough to drive the process in the right way
In [ ]:
import torch
import torchvision
import os
import PIL
import pdb
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import make_grid
from tqdm.auto import tqdm
from PIL import Image
In [ ]:
# OPTIONAL
!pip install wandb -qqq
import wandb
wandb.login(key='9e59c0e5f929ee5a223fead436cba53be4f90e1a')
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.9/6.9 MB 29.3 MB/s eta 0:00:00 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 207.3/207.3 kB 27.7 MB/s eta 0:00:00 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 289.6/289.6 kB 31.5 MB/s eta 0:00:00 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 62.7/62.7 kB 9.6 MB/s eta 0:00:00
wandb: W&B API key is configured. Use `wandb login --relogin` to force relogin wandb: WARNING If you're specifying your api key in code, ensure this code is not shared publicly. wandb: WARNING Consider setting the WANDB_API_KEY environment variable, or running `wandb login` from the command line. wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc
Out[ ]:
True
In [ ]:
def show(tensor, num=25, wandbactivation=0, name=''):
data = tensor.detach().cpu()
grid = make_grid(data[:num], nrow=5).permute(1,2,0)
# optional
# wandb - online activation
if (wandbactivation==1 and wandbact==1):
wandb.log({name:wandb.Image(grid.numpy().clip(0,1))})
# cliping pixels to range(0,1)
plt.imshow(grid.clip(0,1))
plt.show()
In [ ]:
## hyperparameters and general parameters
n_epochs = 1000
batch_size = 128
lr = 1e-4
# z_dim - input noise latent vector dim
z_dim = 200
device = 'cuda'
cur_step = 0
# 5 cycles training of the critic, then 1 of the generator
# generally critic needs more training than the generator
crit_cycles = 5
gen_losses = []
crit_losses = []
show_step = 35
save_step = 35
# optional, tracking stats online
wandbact = 1
In [ ]:
# optional wandb
%%capture
# experiment_name = wandb.util.generate_id()
experiment_name = 'MY EXP'
myrun = wandb.init(
project='wgan',
name=experiment_name,
group=experiment_name,
config={'optimizer':'adam',
'model':'wgan gp',
'epoch':'1000',
'batch_size':128
}
)
config = wandb.config
In [ ]:
# optional wandb
print(experiment_name)
In [ ]:
# generator model
class Generator(nn.Module):
# d_dim - internal dimension for the output of the convolutional layers
def __init__(self, z_dim=64, d_dim=16):
super(Generator, self).__init__()
self.z_dim = z_dim
self.gen = nn.Sequential(
# ConvTranspose2d: in_channels, out_channels, kernel_size, stride=1, padding=0
# n - width or height
# (n - 1) * stride - 2 * padding + kernel_size
# generator starts 1x1 pixel and z_dim number of channels and it gives it dimensionality of latent space
# starting with 200 channels bringing up channels to 512 and increasing size of the image
# 1st block
nn.ConvTranspose2d(in_channels=z_dim, out_channels=d_dim*32, kernel_size=4, stride=1, padding=0),
# normalizing values for improving stability, num_features = out from conv layer
nn.BatchNorm2d(num_features=d_dim*32),
# applying nonlinearity
nn.ReLU(True),
# 2nd block
nn.ConvTranspose2d(in_channels=d_dim*32, out_channels=d_dim*16, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(d_dim*16),
nn.ReLU(inplace=True),
# 3rd block
nn.ConvTranspose2d(in_channels=d_dim*16, out_channels=d_dim*8, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(d_dim*8),
nn.ReLU(inplace=True),
# 4th block
nn.ConvTranspose2d(in_channels=d_dim*8, out_channels=d_dim*4, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(d_dim*4),
nn.ReLU(inplace=True),
# 5th block
nn.ConvTranspose2d(in_channels=d_dim*4, out_channels=d_dim*2, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(d_dim*2),
nn.ReLU(inplace=True),
# 6th block
# in output layer we stay with ONLY 3 channels
nn.ConvTranspose2d(in_channels=d_dim*2, out_channels=3, kernel_size=4, stride=2, padding=1),
# no batch norm
# results from -1 to 1
nn.Tanh()
)
def forward(self, noise):
# generator recives noise as input
x = noise.view(len(noise), self.z_dim, 1, 1) # 128(batch) x 200(z_dim) x 1(height) x 1(width)
return self.gen(x)
def gen_noise(num, z_dim, device='cuda'):
return torch.randn(num, z_dim, device=device) # 128 x 200
# n - width or height
# nn.Conv2d: (n + 2 * pad - ks) // stride + 1
# nn.ConvTranspose2d: (n - 1) * stride - 2 * padding + kernel_size
(n - 1) * stride - 2 * padding + kernel_size
- 1st step:
- (1 - 1) * 1 - 2 * 0 + 4 = 4x4 image, channels: 200 to 512
- 2nd step
- (4 - 1) * 2 - 2 * 1 + 4 = 8x8 image, channels: 512 to 256
- 3rd step
- (8 - 1) * 2 - 2 * 1 + 4 = 16x16 image, channels: 256 to 128
- 4th step
- (16 - 1) * 2 - 2 * 1 + 4 = 32x32 image, channels: 128 to 64
- 5th step
- (32 - 1) * 2 - 2 * 1 + 4 = 64x64 image, channels: 64 to 32
- 6th step
- (64 - 1) * 2 - 2 * 1 + 4 = 128x128 image, channels: 32 to 3
In [ ]:
# critic model
# Conv2d: in_channels, out_channels, kernel_size, stride=1, padding=0
# (n + 2 * padding - kernel_size) // stride + 1
class Critic(nn.Module):
def __init__(self, d_dim=16):
super(Critic, self).__init__()
self.crit = nn.Sequential(
# 1st block
nn.Conv2d(in_channels=3, out_channels=d_dim, kernel_size=4, stride=2, padding=1),
# instead of batchnorm2d, normalizing according to the values of the whole instance insted of values of the batch
nn.InstanceNorm2d(d_dim), # works the best
# leaky keeps information, negative values have little slope(small negative numbers), but theyre not converted to 0
nn.LeakyReLU(0.2),
# 2nd block
nn.Conv2d(in_channels=d_dim, out_channels=d_dim*2, kernel_size=4, stride=2, padding=1),
nn.InstanceNorm2d(d_dim*2),
nn.LeakyReLU(0.2),
# 3rd block
nn.Conv2d(in_channels=d_dim*2, out_channels=d_dim*4, kernel_size=4, stride=2, padding=1),
nn.InstanceNorm2d(d_dim*4),
nn.LeakyReLU(0.2),
# 4th block
nn.Conv2d(in_channels=d_dim*4, out_channels=d_dim*8, kernel_size=4, stride=2, padding=1),
nn.InstanceNorm2d(d_dim*8),
nn.LeakyReLU(0.2),
# 5th block
nn.Conv2d(in_channels=d_dim*8, out_channels=d_dim*16, kernel_size=4, stride=2, padding=1),
nn.InstanceNorm2d(d_dim*16),
nn.LeakyReLU(0.2),
# 6th block
# we return 1 value, either its fake or real
# stride has to be 1 and padding 0
nn.Conv2d(in_channels=d_dim*16, out_channels=1, kernel_size=4, stride=1, padding=0),
)
def forward(self, image):
# image: 128(batch) x 3(channels) x 128(height) x 128(width)
crit_pred = self.crit(image) # 128 x 1 x 1 x 1
return crit_pred.view(len(crit_pred), -1) # 128(batch values) x 1(fake or real)
(n + 2 * padding - kernel_size) // stride + 1
- 1st step
- (128 + 2 * 1 - 4) //2 + 1 = 64x64 image, channels: 3 to 16
- 2nd step
- (64 + 2 * 1 - 4) //2 + 1 = 32x32 image, channels: 16 to 32
- 3rd step
- (32 + 2 * 1 - 4) //2 + 1 = 16x16 image, channels: 32 to 64
- 4th step
- (16 + 2 * 1 - 4) //2 + 1 = 8x8 image, channels: 64 to 128
- 5th step
- (8 + 2 * 1 - 4) //2 + 1 = 4x4 image, channels: 128 to 256
- 6th step
- (4 + 2 * 0 - 4) //1 + 1 = 1x1 image, channels: 256 to 1
Alternative way to initialize parameters
In [ ]:
def init_weights(m):
if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
torch.nn.init.normal(m.weight, 0.0, 0.02)
torch.nn.init.constant(m.bias, 0)
if isinstance(m, nn.BatchNorm2d):
torch.nn.init.normal(m.weight, 0.0, 0.02)
torch.nn.init.constant(m.bias, 0)
# Initializations - NOT HERE
# gen = gen.apply(init_weights)
# crit = crit.apply(init_weights)
In [ ]:
# loading dataset
import gdown, zipfile
# url = ''
path = '/content/drive/MyDrive'
download_path = f'{path}/img_align_celeba.zip'
if not os.path.exists(path):
os.makedirs(path)
# gdown.download(url, download_path, quiet=False)
with zipfile.ZipFile(download_path, 'r') as ziphandler:
ziphandler.extractall('.')
In [ ]:
!wget https://drive.google.com/file/d/0B7EVK8r0v71pZjFTYXZWM3FlRnM/view?resourcekey=0-dYn9z10tMJOBAkviAcfdyQ
!unzip -q img_align_celeba.zip
--2024-06-17 10:52:45-- https://drive.google.com/file/d/0B7EVK8r0v71pZjFTYXZWM3FlRnM/view?resourcekey=0-dYn9z10tMJOBAkviAcfdyQ Resolving drive.google.com (drive.google.com)... 108.177.121.138, 108.177.121.100, 108.177.121.101, ... Connecting to drive.google.com (drive.google.com)|108.177.121.138|:443... connected. HTTP request sent, awaiting response... 200 OK Length: unspecified [text/html] Saving to: ‘view?resourcekey=0-dYn9z10tMJOBAkviAcfdyQ’ view?reso [<=> ] 0 --.-KB/s view?resourcekey=0- [ <=> ] 88.03K --.-KB/s in 0.003s 2024-06-17 10:52:45 (26.0 MB/s) - ‘view?resourcekey=0-dYn9z10tMJOBAkviAcfdyQ’ saved [90138] unzip: cannot find or open img_align_celeba.zip, img_align_celeba.zip.zip or img_align_celeba.zip.ZIP.
In [ ]:
# class Dataset(Dataset):
class Dataset(Dataset):
def __init__(self, path, size=128, lim=10000):
self.sizes = [size, size]
# paths to the images
items, labels = [], []
for data in os.listdir(path)[:lim]:
# path: './data/celeba/img_align_celeba'
# data: '123213.img'
item = os.path.join(path, data)
items.append(item)
labels.append(data)
self.items=items
self.labels=labels
def __len__(self):
return len(self.items)
def __getitem__(self, idx):
# open image idx
data = PIL.Image.open(self.items[idx]).convert('RGB') # size of the image: fe. 1278, 121
data = np.asarray(torchvision.transforms.Resize(self.sizes)(data)) # 128 x 128 x 3
data = np.transpose(data, (2,0,1)).astype(np.float32, copy=False) # 3 x 128 x 128
# from np to tensor for training, div = standarizing
data = torch.from_numpy(data).div(255) # leaving values from 0 to 1
return data, self.labels[idx]
In [ ]:
len(os.listdir('./img_align_celeba'))
Out[ ]:
202599
In [ ]:
# Dataset
data_path = './img_align_celeba'
ds = Dataset(data_path, size=128, lim=50000)
# DataLoader
dataloader = DataLoader(dataset=ds, batch_size=128, shuffle=True)
# Models
gen = Generator(z_dim=z_dim).to(device)
crit = Critic().to(device)
# Optimizers
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=(0.5, 0.9)) # betas - internal calculations, works well with this architecture
crit_opt = torch.optim.Adam(crit.parameters(), lr=lr, betas=(0.5, 0.9))
# Initializations - OPTIONAL
# gen = gen.apply(init_weights)
# crit = crit.apply(init_weights)
# wandb - OPTIONAL
if (wandbact==1):
wandb.watch(gen, log_freq=100)
wandb.watch(crit, log_freq=100)
x, y = next(iter(dataloader))
show(x)
In [ ]:
# gradient penalty calculation
def get_gp(real, fake, crit, alpha, gamma=10): # alpha does random interpolations, gamma stands for intensity of gp regularizations
mix_images = real * alpha + fake * (1-alpha) # 128(batch) x 3 x 128 x 128, linear interpolation
mix_scores = crit(mix_images) # predictions: 128(batch) x 1
# we want to penalize gradients that are to large
# computing and returning the sum of the gradients of the outputs with respect to the inputs
gradient = torch.autograd.grad(
inputs = mix_images,
outputs = mix_scores,
# puting ones to take into account all the grades and outputs
grad_outputs = torch.ones_like(mix_scores),
retain_graph=True,
create_graph=True,
)[0] # return first batch 128(bs) x 3 x 128 x 128
gradient = gradient.view(len(gradient), -1) # 128 x 49512(128x128x3)
gradient_norm = gradient.norm(2, dim=1) # L2 norm
gp = gamma * ((gradient_norm - 1)**2).mean()
return gp
In [ ]:
# saving and loading checkpoints
if not os.path.exists('./content/drive/MyDrive/training_data/'):
os.makedirs('./content/drive/MyDrive/training_data/')
root_path='./content/drive/MyDrive/training_data/'
def save_checkpoint(name):
torch.save({
'epoch':epoch,
'model_state_dict':gen.state_dict(),
'optimizer_state_dict':gen_opt.state_dict(),
}, f'{root_path}G-{name}.pkl')
torch.save({
'epoch':epoch,
'model_state_dict':crit.state_dict(),
'optimizer_state_dict':crit_opt.state_dict(),
}, f'{root_path}C-{name}.pkl')
print('Saved checkpoint')
def load_checkpoint(name):
# generator
# loading file
checkpoint = torch.load(f'{root_path}G-{name}.pkl')
# loading values to the model
gen.load_state_dict(checkpoint['model_state_dict'])
gen_opt.load_state_dict(checkpoint['optimizer_state_dict'])
# critic
checkpoint = torch.load(f'{root_path}C-{name}.pkl')
crit.load_state_dict(checkpoint['model_state_dict'])
crit_opt.load_state_dict(checkpoint['optimizer_state_dict'])
print('Loaded checkpoint')
--------------------------------------------------------------------------- NameError Traceback (most recent call last) <ipython-input-1-39c5fc667a2d> in <cell line: 3>() 1 # saving and loading checkpoints 2 ----> 3 if not os.path.exists('./content/drive/MyDrive/training_data/'): 4 os.makedirs('./content/drive/MyDrive/training_data/') 5 root_path='./content/drive/MyDrive/training_data/' NameError: name 'os' is not defined
In [ ]:
#!cp C-final* ./data/
#!cp G-final* ./data/
epoch=1
save_checkpoint('test')
load_checkpoint('test')
Saved checkpoint Loaded checkpoint
In [ ]:
# Training loop
for epoch in range(n_epochs):
for real, _ in tqdm(dataloader):
cur_bs = len(real) # 128
real = real.to(device)
## Critic
mean_crit_loss = 0
for _ in range(crit_cycles):
# zeroing gradient of the optimizer
crit_opt.zero_grad()
noise = gen_noise(cur_bs, z_dim)
fake = gen(noise)
# detaching for not affecting the parameters of the generator
crit_fake_pred = crit(fake.detach())
crit_real_pred = crit(real)
# alpha vector (numbers size of the batch)
alpha = torch.rand(len(real), 1, 1, 1, device=device, requires_grad=True) # 128 x 1 x 1 x 1
# calculating gradient penalty
gp = get_gp(real, fake.detach(), crit, alpha)
# calculating loss
crit_loss = crit_fake_pred.mean() - crit_real_pred.mean() + gp
#.item - taking only the number from the tensor
mean_crit_loss += crit_loss.item() / crit_cycles
# optimizer backpropagation
crit_loss.backward(retain_graph=True)
crit_opt.step()
# list of losses values
crit_losses += [mean_crit_loss]
## Generator
# zeroing gradient of the optimizer
gen_opt.zero_grad()
# creating noise 128 x 200
noise = gen_noise(cur_bs, z_dim)
# passing noise through generator
fake = gen(noise)
# passing them through critic
crit_fake_pred = crit(fake)
# negative of the pred of the critic
gen_loss = -crit_fake_pred.mean()
# backpropagation
gen_loss.backward()
# updating the parameters of the generator
gen_opt.step()
gen_losses+=[gen_loss.item()]
## Statistics
if (wandb==1):
wandb.log(
{'Epoch':epoch,
'Step':cur_step,
'Critic loss':mean_crit_loss,
'Gen loss':gen_loss,
}
)
if cur_step % save_step == 0 and cur_step > 0:
print('Saving checkpoint:', cur_step, save_step)
# best to save the files with the different names fe. nr of epoch
save_checkpoint('latest')
if (cur_step % show_step == 0 and cur_step > 0):
show(fake, wandbactivation=1, name='fake')
show(real, wandbactivation=1, name='real')
gen_mean = sum(gen_losses[-show_step:]) / show_step
crit_mean = sum(crit_losses[-show_step:]) / show_step
print(f'Epoch: {epoch}, step: {cur_step}, Generator loss: {gen_loss}, Critic loss: {crit_loss}')
plt.plot(range(len(gen_losses)),
torch.Tensor(gen_losses),
label='Generator loss')
plt.plot(range(len(crit_losses)),
torch.Tensor(crit_losses),
label='Critic loss')
plt.ylim(-200,200)
plt.legend()
plt.show()
cur_step += 1
0%| | 0/391 [00:00<?, ?it/s]
/usr/local/lib/python3.10/dist-packages/torch/autograd/graph.py:744: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:919.) return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
Saving checkpoint: 35 35 Saved checkpoint
Epoch: 0, step: 35, Generator loss: 15.25752067565918, Critic loss: -21.56631088256836
Saving checkpoint: 70 35 Saved checkpoint
Epoch: 0, step: 70, Generator loss: 25.56017303466797, Critic loss: -29.439315795898438
Saving checkpoint: 105 35 Saved checkpoint
Epoch: 0, step: 105, Generator loss: 24.86949920654297, Critic loss: -23.476028442382812
Saving checkpoint: 140 35 Saved checkpoint
Epoch: 0, step: 140, Generator loss: 26.923324584960938, Critic loss: -22.520244598388672
Saving checkpoint: 175 35 Saved checkpoint
Epoch: 0, step: 175, Generator loss: 25.521202087402344, Critic loss: -18.931087493896484
Saving checkpoint: 210 35 Saved checkpoint
Epoch: 0, step: 210, Generator loss: 20.612937927246094, Critic loss: -19.237375259399414
Saving checkpoint: 245 35 Saved checkpoint
Epoch: 0, step: 245, Generator loss: 15.082696914672852, Critic loss: -14.787302017211914
Saving checkpoint: 280 35 Saved checkpoint
Epoch: 0, step: 280, Generator loss: 20.040992736816406, Critic loss: -13.67352294921875
Saving checkpoint: 315 35 Saved checkpoint
Epoch: 0, step: 315, Generator loss: 16.89620018005371, Critic loss: -14.186975479125977
Saving checkpoint: 350 35 Saved checkpoint
Epoch: 0, step: 350, Generator loss: 19.051918029785156, Critic loss: -12.683650016784668
Saving checkpoint: 385 35 Saved checkpoint
Epoch: 0, step: 385, Generator loss: 16.979228973388672, Critic loss: -14.533666610717773
/usr/local/lib/python3.10/dist-packages/torch/autograd/graph.py:744: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:919.) return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
0%| | 0/391 [00:00<?, ?it/s]
Saving checkpoint: 420 35 Saved checkpoint
Epoch: 1, step: 420, Generator loss: 17.686763763427734, Critic loss: -14.1609525680542
Saving checkpoint: 455 35 Saved checkpoint
Epoch: 1, step: 455, Generator loss: 18.91667938232422, Critic loss: -12.47118091583252
Saving checkpoint: 490 35 Saved checkpoint
Epoch: 1, step: 490, Generator loss: 18.239500045776367, Critic loss: -10.373126029968262
Saving checkpoint: 525 35 Saved checkpoint
Epoch: 1, step: 525, Generator loss: 16.01999282836914, Critic loss: -11.49423885345459
Saving checkpoint: 560 35 Saved checkpoint
Epoch: 1, step: 560, Generator loss: 15.156608581542969, Critic loss: -12.317692756652832
Saving checkpoint: 595 35 Saved checkpoint
Epoch: 1, step: 595, Generator loss: 16.044208526611328, Critic loss: -10.365501403808594
Saving checkpoint: 630 35 Saved checkpoint
Epoch: 1, step: 630, Generator loss: 18.690975189208984, Critic loss: -10.940778732299805
Saving checkpoint: 665 35 Saved checkpoint
Epoch: 1, step: 665, Generator loss: 17.979543685913086, Critic loss: -10.223814010620117
Saving checkpoint: 700 35 Saved checkpoint
Epoch: 1, step: 700, Generator loss: 14.435104370117188, Critic loss: -13.017997741699219
Saving checkpoint: 735 35 Saved checkpoint
Epoch: 1, step: 735, Generator loss: 19.215503692626953, Critic loss: -8.568474769592285
Saving checkpoint: 770 35 Saved checkpoint
Epoch: 1, step: 770, Generator loss: 22.676485061645508, Critic loss: -13.234884262084961
0%| | 0/391 [00:00<?, ?it/s]
Saving checkpoint: 805 35 Saved checkpoint
Epoch: 2, step: 805, Generator loss: 20.212373733520508, Critic loss: -9.45648193359375
Saving checkpoint: 840 35 Saved checkpoint
Epoch: 2, step: 840, Generator loss: 19.213638305664062, Critic loss: -9.419897079467773
Saving checkpoint: 875 35 Saved checkpoint
Epoch: 2, step: 875, Generator loss: 17.558929443359375, Critic loss: -9.927236557006836
Saving checkpoint: 910 35 Saved checkpoint
Epoch: 2, step: 910, Generator loss: 23.620121002197266, Critic loss: -10.862022399902344
Saving checkpoint: 945 35 Saved checkpoint
Epoch: 2, step: 945, Generator loss: 18.534482955932617, Critic loss: -9.245712280273438
Saving checkpoint: 980 35 Saved checkpoint
Epoch: 2, step: 980, Generator loss: 23.425382614135742, Critic loss: -10.018635749816895
Saving checkpoint: 1015 35 Saved checkpoint
Epoch: 2, step: 1015, Generator loss: 20.67441177368164, Critic loss: -9.165563583374023
Saving checkpoint: 1050 35 Saved checkpoint
Epoch: 2, step: 1050, Generator loss: 19.643081665039062, Critic loss: -8.507790565490723
Saving checkpoint: 1085 35 Saved checkpoint
Epoch: 2, step: 1085, Generator loss: 20.506515502929688, Critic loss: -10.645575523376465
Saving checkpoint: 1120 35 Saved checkpoint
Epoch: 2, step: 1120, Generator loss: 24.903501510620117, Critic loss: -8.776479721069336
Saving checkpoint: 1155 35 Saved checkpoint
Epoch: 2, step: 1155, Generator loss: 25.28638458251953, Critic loss: -10.320352554321289
0%| | 0/391 [00:00<?, ?it/s]
Saving checkpoint: 1190 35 Saved checkpoint
Epoch: 3, step: 1190, Generator loss: 16.862964630126953, Critic loss: -9.55805778503418
Saving checkpoint: 1225 35 Saved checkpoint
Epoch: 3, step: 1225, Generator loss: 16.591989517211914, Critic loss: -8.997180938720703
Saving checkpoint: 1260 35 Saved checkpoint
Epoch: 3, step: 1260, Generator loss: 20.295799255371094, Critic loss: -8.942169189453125
Saving checkpoint: 1295 35 Saved checkpoint
Epoch: 3, step: 1295, Generator loss: 16.82471466064453, Critic loss: -8.582033157348633
Saving checkpoint: 1330 35 Saved checkpoint
Epoch: 3, step: 1330, Generator loss: 19.550825119018555, Critic loss: -8.895916938781738
Saving checkpoint: 1365 35 Saved checkpoint
Epoch: 3, step: 1365, Generator loss: 19.583209991455078, Critic loss: -8.709506034851074
Saving checkpoint: 1400 35 Saved checkpoint
Epoch: 3, step: 1400, Generator loss: 18.479915618896484, Critic loss: -9.812521934509277
Saving checkpoint: 1435 35 Saved checkpoint
Epoch: 3, step: 1435, Generator loss: 18.085033416748047, Critic loss: -8.537590026855469
Saving checkpoint: 1470 35 Saved checkpoint
Epoch: 3, step: 1470, Generator loss: 15.994598388671875, Critic loss: -8.235672950744629
Saving checkpoint: 1505 35 Saved checkpoint
Epoch: 3, step: 1505, Generator loss: 17.538623809814453, Critic loss: -9.286766052246094
Saving checkpoint: 1540 35 Saved checkpoint
Epoch: 3, step: 1540, Generator loss: 14.294757843017578, Critic loss: -10.109136581420898
0%| | 0/391 [00:00<?, ?it/s]
--------------------------------------------------------------------------- KeyboardInterrupt Traceback (most recent call last) <ipython-input-18-d641a5009044> in <cell line: 3>() 2 3 for epoch in range(n_epochs): ----> 4 for real, _ in tqdm(dataloader): 5 cur_bs = len(real) # 128 6 real = real.to(device) /usr/local/lib/python3.10/dist-packages/tqdm/notebook.py in __iter__(self) 248 try: 249 it = super().__iter__() --> 250 for obj in it: 251 # return super(tqdm...) will not catch exception 252 yield obj /usr/local/lib/python3.10/dist-packages/tqdm/std.py in __iter__(self) 1179 1180 try: -> 1181 for obj in iterable: 1182 yield obj 1183 # Update and possibly print the progressbar. /usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py in __next__(self) 629 # TODO(https://github.com/pytorch/pytorch/issues/76750) 630 self._reset() # type: ignore[call-arg] --> 631 data = self._next_data() 632 self._num_yielded += 1 633 if self._dataset_kind == _DatasetKind.Iterable and \ /usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py in _next_data(self) 673 def _next_data(self): 674 index = self._next_index() # may raise StopIteration --> 675 data = self._dataset_fetcher.fetch(index) # may raise StopIteration 676 if self._pin_memory: 677 data = _utils.pin_memory.pin_memory(data, self._pin_memory_device) /usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index) 49 data = self.dataset.__getitems__(possibly_batched_index) 50 else: ---> 51 data = [self.dataset[idx] for idx in possibly_batched_index] 52 else: 53 data = self.dataset[possibly_batched_index] /usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py in <listcomp>(.0) 49 data = self.dataset.__getitems__(possibly_batched_index) 50 else: ---> 51 data = [self.dataset[idx] for idx in possibly_batched_index] 52 else: 53 data = self.dataset[possibly_batched_index] <ipython-input-12-d25a227d10ee> in __getitem__(self, idx) 22 # open image idx 23 data = PIL.Image.open(self.items[idx]).convert('RGB') # size of the image: fe. 1278, 121 ---> 24 data = np.asarray(torchvision.transforms.Resize(self.sizes)(data)) # 128 x 128 x 3 25 data = np.transpose(data, (2,0,1)).astype(np.float32, copy=False) # 3 x 128 x 128 26 # from np to tensor for training, div = standarizing /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs) 1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1531 else: -> 1532 return self._call_impl(*args, **kwargs) 1533 1534 def _call_impl(self, *args, **kwargs): /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs) 1539 or _global_backward_pre_hooks or _global_backward_hooks 1540 or _global_forward_hooks or _global_forward_pre_hooks): -> 1541 return forward_call(*args, **kwargs) 1542 1543 try: /usr/local/lib/python3.10/dist-packages/torchvision/transforms/transforms.py in forward(self, img) 352 PIL Image or Tensor: Rescaled image. 353 """ --> 354 return F.resize(img, self.size, self.interpolation, self.max_size, self.antialias) 355 356 def __repr__(self) -> str: /usr/local/lib/python3.10/dist-packages/torchvision/transforms/functional.py in resize(img, size, interpolation, max_size, antialias) 466 warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.") 467 pil_interpolation = pil_modes_mapping[interpolation] --> 468 return F_pil.resize(img, size=output_size, interpolation=pil_interpolation) 469 470 return F_t.resize(img, size=output_size, interpolation=interpolation.value, antialias=antialias) /usr/local/lib/python3.10/dist-packages/torchvision/transforms/_functional_pil.py in resize(img, size, interpolation) 248 raise TypeError(f"Got inappropriate size arg: {size}") 249 --> 250 return img.resize(tuple(size[::-1]), interpolation) 251 252 /usr/local/lib/python3.10/dist-packages/PIL/Image.py in resize(self, size, resample, box, reducing_gap) 2190 ) 2191 -> 2192 return self._new(self.im.resize(size, resample, box)) 2193 2194 def reduce(self, factor, box=None): KeyboardInterrupt:
10000 / 128 = 78.125 - 79 steps per epoch
In [ ]:
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()